import os
## Set directory
os.chdir('/hpc/group/pbenfeylab/CheWei/CW_data/genesys')
import networkx as nx
from genesys_evaluate import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import warnings
# Suppress all warning messages
warnings.filterwarnings("ignore", category=DeprecationWarning)
## Conda Env pytorch-gpu on DCC
print(torch.__version__)
print(sc.__version__)
1.13.0.post200 1.9.1
## Genes considered/used (shared among samples)
gene_list = pd.read_csv('./gene_list_1108.csv')
with open("./genesys_root_data.pkl", 'rb') as file_handle:
data = pickle.load(file_handle)
batch_size = 2000
dataset = Root_Dataset(data['X_test'], data['y_test'])
loader = DataLoader(dataset,
batch_size = batch_size,
shuffle = True, drop_last=True)
input_size = data['X_train'].shape[1]
## 10 cell types
output_size = 10
embedding_dim = 256
hidden_dim = 256
n_layers = 2
device = "cpu"
path = "./"
model = ClassifierLSTM(input_size, output_size, embedding_dim, hidden_dim, n_layers).to(device)
model.load_state_dict(torch.load(path+"best_ALL_1130_continue.pth", map_location=torch.device('cpu')))
model = model
model.eval()
ClassifierLSTM(
(fc1): Sequential(
(0): Linear(in_features=17513, out_features=256, bias=True)
(1): Dropout(p=0.2, inplace=False)
(2): GaussianNoise()
)
(fc): Sequential(
(0): ReLU()
(1): Linear(in_features=512, out_features=512, bias=True)
(2): ReLU()
(3): Linear(in_features=512, out_features=10, bias=True)
)
(lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
(dropout): Dropout(p=0.2, inplace=False)
(b_to_z): DBlock(
(fc1): Linear(in_features=512, out_features=256, bias=True)
(fc2): Linear(in_features=512, out_features=256, bias=True)
(fc_mu): Linear(in_features=256, out_features=512, bias=True)
(fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
)
(bz2_infer_z1): DBlock(
(fc1): Linear(in_features=1024, out_features=256, bias=True)
(fc2): Linear(in_features=1024, out_features=256, bias=True)
(fc_mu): Linear(in_features=256, out_features=512, bias=True)
(fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
)
(z1_to_z2): DBlock(
(fc1): Linear(in_features=512, out_features=256, bias=True)
(fc2): Linear(in_features=512, out_features=256, bias=True)
(fc_mu): Linear(in_features=256, out_features=512, bias=True)
(fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
)
(z_to_x): Decoder(
(fc1): Linear(in_features=512, out_features=256, bias=True)
(fc2): Linear(in_features=256, out_features=256, bias=True)
(fc3): Linear(in_features=256, out_features=17513, bias=True)
)
)
classes = ['Columella', 'Lateral Root Cap', 'Phloem', 'Xylem', 'Procambium', 'Pericycle', 'Endodermis', 'Cortex', 'Atrichoblast', 'Trichoblast']
class2num = {c: i for (i, c) in enumerate(classes)}
num2class = {i: c for (i, c) in enumerate(classes)}
cts = ['Atrichoblast','Trichoblast','Cortex','Endodermis','Pericycle','Procambium','Xylem','Phloem','Lateral Root Cap','Columella']
ctw = np.zeros((len(cts), 17513, 17513))
## number of cells sampled from the atlas
batch_size = 2000
## GRN for the transition t7 to t9
for ct in cts:
print(ct)
cws = np.zeros((len(loader), 17513, 17513))
with torch.no_grad():
for i, sample in enumerate(loader):
x = sample['x'].to(device)
y = sample['y'].to(device)
y_label = [num2class[i] for i in y.tolist()]
pred_h = model.init_hidden(batch_size)
tfrom = model.generate_next(x, pred_h, 7).to('cpu').detach().numpy()
cfrom = tfrom[np.where(np.array(y_label)==ct)[0],:]
pred_h = model.init_hidden(batch_size)
tto = model.generate_next(x, pred_h, 9).to('cpu').detach().numpy()
cto = tto[np.where(np.array(y_label)==ct)[0],:]
cw = torch.linalg.lstsq(torch.tensor(cfrom), torch.tensor(cto)).solution.detach().numpy()
cws[i] = cw
## Calculate mean across number of repeats
cwm = np.mean(cws, axis=0)
ctw[cts.index(ct)] = cwm
Atrichoblast Trichoblast Cortex Endodermis Pericycle Procambium Xylem Phloem Lateral Root Cap Columella
# Save the array to disk
np.save('genesys_ctw_t7-t9.npy', ctw)
ctw = np.load('genesys_ctw_t7-t9.npy')
## Calculate z-scores
ctw_z = np.zeros((len(cts), 17513, 17513))
for i in range(len(cts)):
ctw_z[i] = (ctw[i] - np.mean(ctw[i])) / np.std(ctw[i])
## Filtering based on z-scores (with no weights)
ctw_f = np.zeros((len(cts), 17513, 17513))
## z-score threshold (keep values > mean + threshold*std)
threshold=3
for i in range(len(cts)):
ctw_f[i] = np.abs(ctw_z[i]) > threshold
wanted_TFs = pd.read_csv("./Kay_TF_thalemine_annotations.csv")
## Make TF names unique and assign preferred names
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT2G33880"]="WOX9"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT2G45160"]="SCL27"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G04410"]="NAC78"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT3G29035"]="ORS1"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT2G02540"]="ZHD3"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT3G16500"]="IAA26"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G09740"]="HAG5"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT4G24660"]="ZHD2"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G46880"]="HDG5"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G28420"]="RLT1"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G14580"]="BLJ"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT3G45260"]="BIB"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT2G02070"]="RVN"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT2G28160"]="FIT"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G68360"]="GIS3"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G20640"]="NLP4"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G05550"]="VFP5"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT3G59470"]="FRF1"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G15150"]="HAT7"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G14750"]="WER"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G75710"]="BRON"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G74500"]="TMO7"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT2G12646"]="RITF1"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT3G48100"]="ARR5"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT4G16141"]="GATA17L"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G65640"]="NFL"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G62700"]="VND5"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT4G36160"]="VND2"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G66300"]="VND3"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT1G12260"]="VND4"
wanted_TFs['Name'][wanted_TFs['GeneID']=="AT5G62380"]="VND6"
pd.Series(wanted_TFs['Name']).value_counts().head(5)
NAC001 1 PRE5 1 MYB118 1 MYB21 1 MYB0 1 Name: Name, dtype: int64
TFidx = []
for i in wanted_TFs['GeneID']:
if i in gene_list['features'].tolist():
TFidx.append(np.where(gene_list['features']==i)[0][0])
TFidx = np.sort(np.array(TFidx))
def network(i):
## No weights
adj_nw = ctw_f[i]
## Weighted
adj = ctw[i]*ctw_f[i]
## TF only
adj = adj[np.ix_(TFidx,TFidx)]
adj_nw = adj_nw[np.ix_(TFidx,TFidx)]
## Remove no connect
regidx = np.sort(np.array(pd.Series(np.where(adj_nw==True)[0]).value_counts().index[pd.Series(np.where(adj_nw==True)[0]).value_counts()>=1]))
taridx = np.sort(np.array(pd.Series(np.where(adj_nw==True)[1]).value_counts().index[pd.Series(np.where(adj_nw==True)[1]).value_counts()>=1]))
## Reciprocol
keepidx = np.sort(np.array(list(set(regidx).intersection(taridx))))
#keepidx = np.sort(np.array(list(set(regidx).union(taridx))))
TFID = np.array(gene_list['features'][TFidx])[keepidx].tolist()
## TF name to keep
TFname = []
for i in np.array(gene_list['features'][TFidx])[keepidx]:
TFname.append(wanted_TFs['Name'][np.where(wanted_TFs['GeneID']==i)[0][0]])
adj = adj[np.ix_(keepidx,keepidx)]
# Create a NetworkX graph for non-directed edges
G = nx.Graph() # supports directed edges and allows for multiple edges between the same pair of nodes
# Add nodes to the graph
num_nodes = adj.shape[0]
for i, name in enumerate(TFname):
G.add_node(i, name=name)
# Add edges to the graph with weights
for i in range(num_nodes):
for j in range(num_nodes):
weight = adj[i, j]
if weight != 0:
G.add_edge(i, j, weight=abs(weight), distance=1/abs(weight))
## Measures the extent to which how close a node is to all other nodes in the network, considering the shortest paths or geodesic distances between nodes
closeness_centrality = nx.closeness_centrality(G, distance='distance')
## Measures the extent to which a node that are not only well-connected but also connected to other well-connected nodes.
eigenvector_centrality = nx.eigenvector_centrality(G)
# Create a NetworkX graph for diected edges
G = nx.MultiDiGraph() # supports directed edges and allows for multiple edges between the same pair of nodes
# Add nodes to the graph
num_nodes = adj.shape[0]
for i, name in enumerate(TFname):
G.add_node(i, name=name)
# Add edges to the graph with weights
for i in range(num_nodes):
for j in range(num_nodes):
weight = adj[i, j]
if weight != 0:
G.add_edge(i, j, weight=weight)
## Measures the number of connections (edges) each node has
degree_centrality = nx.degree_centrality(G)
# Calculate outgoing centrality
out_centrality = nx.out_degree_centrality(G)
# Calculate incoming centrality
in_centrality = nx.in_degree_centrality(G)
## Measures the extent to which a node lies on the shortest paths between other nodes.
betweenness_centrality = nx.betweenness_centrality(G, weight='weight')
## Non_Reciprocal Out centrality
# Visualize the graph
pos = nx.spring_layout(G) # Positions of the nodes
# Node colors based on weighted betweenness centrality
node_colors = [out_centrality[node] for node in G.nodes()]
# Node sizes based on weighted betweenness centrality
node_sizes = [out_centrality[node] * 1000 for node in G.nodes()]
# Get the edge weights as a dictionary
edge_weights = nx.get_edge_attributes(G, 'weight')
edge_colors = ['red' if weight > 0 else 'blue' for (_, _, weight) in G.edges(data='weight')]
# Scale the edge weights to desired linewidths
max_weight = max(edge_weights.values())
edge_widths = [float(edge_weights[edge]) / max_weight for edge in G.edges]
# Draw the graph
nx.draw(G, pos=pos, node_color=node_colors, node_size=node_sizes, with_labels=False, width=edge_widths, edge_color=edge_colors)
# Add node labels
labels = {node: G.nodes[node]['name'] for node in G.nodes}
nx.draw_networkx_labels(G, pos=pos, labels=labels, font_size=8)
# Add a colorbar to show the weighted betweenness centrality color mapping
sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=min(node_colors), vmax=max(node_colors)))
sm.set_array([])
plt.colorbar(sm)
# Show the plot
plt.show()
dc = pd.DataFrame.from_dict(degree_centrality, orient='index', columns=['degree_centrality'])
oc = pd.DataFrame.from_dict(out_centrality, orient='index', columns=['out_centrality'])
ic = pd.DataFrame.from_dict(in_centrality, orient='index', columns=['in_centrality'])
bc = pd.DataFrame.from_dict(betweenness_centrality, orient='index', columns=['betweenness_centrality'])
cc = pd.DataFrame.from_dict(closeness_centrality, orient='index', columns=['closeness_centrality'])
ec = pd.DataFrame.from_dict(eigenvector_centrality, orient='index', columns=['eigenvector_centrality'])
df = pd.concat([dc,oc,ic,bc,cc,ec], axis=1)
df.index =TFname
df = df.sort_values('betweenness_centrality', ascending=False)
return(df)
atri = network(0)
tri = network(1)
cor = network(2)
end = network(3)
per = network(4)
pro = network(5)
xyl = network(6)
phl = network(7)
lrc = network(8)
col = network(9)